{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PaddlePaddle高层API\n",
    "\n",
    "# 第五天：基于seq2seq的对联生成\n",
    "\n",
    "## 1.一些基本概念的回顾\n",
    "\n",
    "\n",
    "### 1.1序列到序列问题\n",
    "\n",
    "这里，我们将根据上联，自动写下联。这是一个典型的序列到序列(sequence2sequence, seq2seq）建模的场景，编码器-解码器（Encoder-Decoder）框架是解决seq2seq问题的经典方法，它能够将一个任意长度的源序列转换成另一个任意长度的目标序列：编码阶段将整个源序列编码成一个向量，解码阶段通过最大化预测序列概率，从中解码出整个目标序列。编码和解码的过程通常都使用RNN实现。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/e9dde4be7d0142068c5c921a1ca6a227a49aad4a8751425faead42f0348f5e01\" width=\"500\" height=\"313\" ></center>\n",
    "<br><center>图1：encoder-decoder示意图</center></br>\n",
    "\n",
    "\n",
    "这里的Encoder采用LSTM，Decoder采用带有attention机制的LSTM。 \n",
    "\n",
    "我们将以对联的上联作为Encoder的输出，下联作为Decoder的输入，训练模型。\n",
    "\n",
    "### 1.2 attention机制\n",
    "\n",
    "Attention机制是很早就被提出来，但是在几年前才在NLP领域开始大放异彩的，尤其是谷歌大脑的Attention is all you need是非常有名的一篇文章（地址：https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf）\n",
    "\n",
    "文中对Attention的解释\n",
    "\n",
    "```\n",
    "An attention function can be described as mapping a query and a set of key-value pairs to an output,\n",
    "where the query, keys, values, and output are all vectors. The output is computed as a weighted sum\n",
    "of the values, where the weight assigned to each value is computed by a compatibility function of the\n",
    "query with the corresponding key.\n",
    "```\n",
    "\n",
    "Attention的详细讲解，也可以看看Aistudio上李宏毅老师的课程：机器学习进阶课节5：序列到序列\n",
    "\n",
    "https://aistudio.baidu.com/aistudio/education/group/info/1979\n",
    "\n",
    "其中讲到了在机器翻译中attention机制的应用：\n",
    "\n",
    "\n",
    "<img src='图片/attention1.png'>\n",
    "\n",
    "演算过程：$z_0$是attention的权重，乘以输入值\n",
    "\n",
    "<img src='图片/attention2.png'>\n",
    "\n",
    "\n",
    "## 2.代码中重要的API\n",
    "\n",
    "`paddle.nn.LSTM`的输入值\n",
    "\n",
    "|函数|意义|\n",
    "|--|--|\n",
    "input_size (int)|输入的大小。\n",
    "hidden_size (int)|隐藏状态大小。\n",
    "num_layers (int，可选)|网络层数。默认为1。\n",
    "direction (str，可选)|网络迭代方向，可设置为forward或bidirect（或bidirectional）。默认为forward。\n",
    "time_major (bool，可选)|指定input的第一个维度是否是time steps。默认为False。\n",
    "dropout (float，可选)|dropout概率，指的是出第一层外每层输入时的dropout概率。默认为0。\n",
    "weight_ih_attr (ParamAttr，可选)|weight_ih的参数。默认为None。\n",
    "weight_hh_attr (ParamAttr，可选)|weight_hh的参数。默认为None。\n",
    "bias_ih_attr (ParamAttr，可选)|bias_ih的参数。默认为None。\n",
    "bias_hh_attr (ParamAttr，可选)|bias_hh的参数。默认为None。\n",
    "\n",
    "\n",
    "这个和paddlenlp里面的LSTM有什么区别？\n",
    "\n",
    "\n",
    "`paddle.nn.RNN`的输入值\n",
    "\n",
    "|函数|功能|\n",
    "|--|--|\n",
    "inputs (Tensor)|输入（可以是多层嵌套的）。如果time_major为False，则Tensor的形状为[batch_size,time_steps,input_size]，如果time_major为True，则Tensor的形状为[time_steps,batch_size,input_size]，input_size为cell的input_size。\n",
    "initial_states (Tensor list tuple，可选)|输入cell的初始状态（可以是多层嵌套的），如果没有给出则会调用 cell.get_initial_states 生成初始状态。默认为None。\n",
    "sequence_length (Tensor，可选)|指定输入序列的长度，形状为[batch_size]，数据类型为int64或int32。在输入序列中所有time step不小于sequence_length的元素都会被当作填充元素处理（状态不再更新）。\n",
    "\n",
    "在https://github.com/PaddlePaddle/PaddleNLP/tree/develop/paddlenlp/seq2vec 中只有encoder的LSTM\n",
    "\n",
    "## 3.训练中注意的地方\n",
    "\n",
    "需要把前面的主网络模型中用于debug的print函数都注释掉，以免输出太多无用信息\n",
    "\n",
    "同时输出时选择`verbose=1`，可以使得输出结果更加美观\n",
    "\n",
    "```\n",
    "Epoch 1/20\n",
    "step 1373/1373 [==============================] - loss: 45.4127 - Perplexity: 305.7490 - ETA: 35s - 95ms/ste - loss: 44.8421 - Perplexity: 244.7160 - 96ms/step           \n",
    "save checkpoint at /home/aistudio/couplet_models/0\n",
    "```\n",
    "\n",
    "\n",
    "为了加快训练，分到高显存的V100显卡的时候，可以加大batch_size到512\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade paddlenlp>=2.0.0b "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.0rc1'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import paddlenlp\n",
    "paddlenlp.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import io\n",
    "import os\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import paddle.nn.functional as F\n",
    "from paddlenlp.data import Vocab, Pad\n",
    "from paddlenlp.metrics import Perplexity\n",
    "from paddlenlp.datasets import CoupletDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据部分\n",
    "\n",
    "## 数据集介绍\n",
    "\n",
    "采用开源的对联数据集[couplet-clean-dataset](https://github.com/v-zich/couplet-clean-dataset)，该数据集过滤了[\n",
    "couplet-dataset](https://github.com/wb14123/couplet-dataset)中的低俗、敏感内容。\n",
    "\n",
    "这个数据集包含70w多条训练样本，1000条验证样本和1000条测试样本。\n",
    "\n",
    "下面列出一些训练集中对联样例：\n",
    "\n",
    "上联：晚风摇树树还挺\t下联：晨露润花花更红\n",
    "\n",
    "上联：愿景天成无墨迹\t下联：万方乐奏有于阗\n",
    "\n",
    "上联：丹枫江冷人初去\t下联：绿柳堤新燕复来\n",
    "\n",
    "上联：闲来野钓人稀处\t下联：兴起高歌酒醉中"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集\n",
    "\n",
    "`paddlenlp.datasets`中内置了多个常见数据集，包括这里的对联数据集`CoupletDataset`。\n",
    "\n",
    "<br>\n",
    "\n",
    "`paddlenlp.datasets`均继承`paddle.io.Dataset`，支持`paddle.io.Dataset`的所有功能：\n",
    "\n",
    "- 通过`len()`函数返回数据集长度，即样本数量。\n",
    "- 下标索引：通过下标索引[n]获取第n条样本。\n",
    "- 遍历数据集，获取所有样本。\n",
    "\n",
    "此外，`paddlenlp.datasets`，还支持如下操作：\n",
    "- 调用`get_datasets()`函数，传入list或者string，获取相对应的train_dataset、development_dataset、test_dataset等。其中train为训练集，用于模型训练； development为开发集，也称验证集validation_dataset，用于模型参数调优；test为测试集，用于评估算法的性能，但不会根据测试集上的表现再去调整模型或参数。\n",
    "- 调用`apply()`函数，对数据集进行指定操作。\n",
    "\n",
    "<br>\n",
    "\n",
    "这里的`CoupletDataset`数据集继承`TranslationDataset`，继承自`paddlenlp.datasets`，除以上通用用法外，还有一些个性设计：\n",
    "- 在`CoupletDataset class`中，还定义了`transform`函数，用于在每个句子的前后加上起始符`<s>`和结束符`</s>`，并将原始数据映射成id序列。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/d6c36cfd88eb4d0d87884f6c9cd47e466c2b562411394606be6683af08733045\" width=\"200\" height=\"200\" ></center>\n",
    "<br><center>图3：token-to-id示意图</center></br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 21421/21421 [00:01<00:00, 18848.18it/s]\n"
     ]
    }
   ],
   "source": [
    "train_ds, dev_ds, test_ds = CoupletDataset.get_datasets(['train', 'dev', 'test'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "来看看数据集有多大，长什么样："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "702594 999 1000\n",
      "([1, 447, 3, 509, 153, 153, 279, 1517, 2], [1, 816, 294, 378, 9, 9, 142, 32, 2])\n",
      "([1, 594, 185, 10, 71, 18, 158, 912, 2], [1, 14, 105, 107, 835, 20, 268, 3855, 2])\n",
      "([1, 335, 830, 68, 425, 4, 482, 246, 2], [1, 94, 51, 1115, 23, 141, 761, 17, 2])\n",
      "([1, 126, 17, 217, 802, 4, 1103, 118, 2], [1, 125, 205, 47, 55, 57, 78, 15, 2])\n",
      "([1, 1203, 228, 390, 10, 1921, 827, 474, 2], [1, 1699, 89, 426, 317, 314, 43, 374, 2])\n",
      "\n",
      "\n",
      "([1, 6, 201, 350, 54, 1156, 2], [1, 64, 522, 305, 543, 102, 2])\n",
      "([1, 168, 1402, 61, 270, 11, 195, 253, 2], [1, 435, 782, 1046, 36, 188, 1016, 56, 2])\n",
      "([1, 744, 185, 744, 6, 18, 452, 16, 1410, 2], [1, 286, 102, 286, 74, 20, 669, 280, 261, 2])\n",
      "([1, 2577, 496, 1133, 60, 107, 2], [1, 1533, 318, 625, 1401, 172, 2])\n",
      "([1, 163, 261, 6, 64, 116, 350, 253, 2], [1, 96, 579, 13, 463, 16, 774, 586, 2])\n"
     ]
    }
   ],
   "source": [
    "print (len(train_ds), len(test_ds), len(dev_ds))\n",
    "for i in range(5):\n",
    "    print (train_ds[i])\n",
    "\n",
    "print ('\\n')\n",
    "for i in range(5):\n",
    "    print (test_ds[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 1 2\n"
     ]
    }
   ],
   "source": [
    "vocab, _ = CoupletDataset.get_vocab()\n",
    "trg_idx2word = vocab.idx_to_token\n",
    "vocab_size = len(vocab)\n",
    "\n",
    "pad_id = vocab[CoupletDataset.EOS_TOKEN]\n",
    "bos_id = vocab[CoupletDataset.BOS_TOKEN]\n",
    "eos_id = vocab[CoupletDataset.EOS_TOKEN]\n",
    "print (pad_id, bos_id, eos_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构造dataloder\n",
    "\n",
    "使用`paddle.io.DataLoader`来创建训练和预测时所需要的`DataLoader`对象。\n",
    "\n",
    "`paddle.io.DataLoader`返回一个迭代器，该迭代器根据`batch_sampler`指定的顺序迭代返回dataset数据。支持单进程或多进程加载数据，快！\n",
    "\n",
    "<br>\n",
    "\n",
    "接收如下重要参数：\n",
    "-  `batch_sampler`：批采样器实例，用于在`paddle.io.DataLoader` 中迭代式获取mini-batch的样本下标数组，数组长度与 batch_size 一致。\n",
    "-  `collate_fn`：指定如何将样本列表组合为mini-batch数据。传给它参数需要是一个`callable`对象，需要实现对组建的batch的处理逻辑，并返回每个batch的数据。在这里传入的是`prepare_input`函数，对产生的数据进行pad操作，并返回实际长度等。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PaddleNLP提供了许多NLP任务中，用于数据处理、组batch数据的相关API。\n",
    "\n",
    "| API                             | 简介                                       |\n",
    "| ------------------------------- | :----------------------------------------- |\n",
    "| `paddlenlp.data.Stack`          | 堆叠N个具有相同shape的输入数据来构建一个batch |\n",
    "| `paddlenlp.data.Pad`            | 将长度不同的多个句子padding到统一长度，取N个输入数据中的最大长度 |\n",
    "| `paddlenlp.data.Tuple`          | 将多个batchify函数包装在一起 |\n",
    "\n",
    "更多数据处理操作详见： [https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/data.md](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/data.md)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def create_data_loader(dataset):\n",
    "    data_loader = paddle.io.DataLoader(\n",
    "        dataset,\n",
    "        batch_sampler=None,\n",
    "        batch_size = batch_size,\n",
    "        collate_fn=partial(prepare_input, pad_id=pad_id))\n",
    "    return data_loader\n",
    "\n",
    "def prepare_input(insts, pad_id):\n",
    "    src, src_length = Pad(pad_val=pad_id, ret_length=True)([inst[0] for inst in insts])\n",
    "    tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)([inst[1] for inst in insts])\n",
    "    tgt_mask = (tgt[:, :-1] != pad_id).astype(paddle.get_default_dtype())\n",
    "    return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "702594 1373 512\n",
      "5\n",
      "0 [512, 18]\n",
      "1 [512]\n",
      "2 [512, 17]\n",
      "3 [512, 17, 1]\n",
      "4 [512, 17]\n"
     ]
    }
   ],
   "source": [
    "# use_gpu = False\n",
    "use_gpu = True\n",
    "device = paddle.set_device(\"gpu\" if use_gpu else \"cpu\")\n",
    "\n",
    "batch_size = 512\n",
    "num_layers = 2\n",
    "dropout = 0.2\n",
    "hidden_size =256\n",
    "max_grad_norm = 5.0\n",
    "learning_rate = 0.001\n",
    "max_epoch = 20\n",
    "model_path = './couplet_models'\n",
    "log_freq = 1000\n",
    "\n",
    "# Define dataloader\n",
    "train_loader = create_data_loader(train_ds)\n",
    "test_loader = create_data_loader(test_ds)\n",
    "\n",
    "print(len(train_ds), len(train_loader), batch_size)\n",
    "# 702594 5490 128  共5490个batch\n",
    "\n",
    "for i in train_loader:\n",
    "    print (len(i))\n",
    "    for ind, each in enumerate(i):\n",
    "        print (ind, each.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型部分\n",
    "下图是带有Attention的Seq2Seq模型结构。下面我们分别定义网络的每个部分，最后构建Seq2Seq主网络。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/8a9dda0434a14fb2a0837702e5f2f1096346810702aa4a6ab1fa7dafe548add6\" width=\"600\" height=\"600\" ></center>\n",
    "<br><center>图5：带有attention机制的encoder-decoder原理示意图</center></br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义Encoder\n",
    "\n",
    "Encoder部分非常简单，可以直接利用PaddlePaddle2.0提供的RNN系列API的`nn.LSTM`。\n",
    "\n",
    "\n",
    "1. `nn.Embedding`：该接口用于构建 Embedding 的一个可调用对象，根据输入的size (vocab_size, embedding_dim)自动构造一个二维embedding矩阵，用于table-lookup。查表过程如下：\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/54276195f4ce44b9ace89c5153300782a744a98343454410898fcfe81333f131\" width=\"700\" height=\"600\" ></center>\n",
    "<br><center>图5：token-to-id & 查表获取向量示意图</center></br>\n",
    "\n",
    "2. `nn.LSTM`：提供序列，得到`encoder_output`和`encoder_state`。    \n",
    "参数：   \n",
    "- input_size (int) 输入的大小。\n",
    "- hidden_size (int) - 隐藏状态大小。\n",
    "- num_layers (int，可选) - 网络层数。默认为1。\n",
    "- direction (str，可选) - 网络迭代方向，可设置为forward或bidirect（或bidirectional）。默认为forward。\n",
    "- time_major (bool，可选) - 指定input的第一个维度是否是time steps。默认为False。\n",
    "- dropout (float，可选) - dropout概率，指的是出第一层外每层输入时的dropout概率。默认为0。\n",
    "\n",
    "[https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/layer/rnn/LSTM_cn.html](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/layer/rnn/LSTM_cn.html)\n",
    "\n",
    "\n",
    "输出:\n",
    "\n",
    "outputs (Tensor) - 输出，由前向和后向cell的输出拼接得到。如果time_major为True，则Tensor的形状为[time_steps,batch_size,num_directions * hidden_size]，如果time_major为False，则Tensor的形状为[batch_size,time_steps,num_directions * hidden_size]，当direction设置为bidirectional时，num_directions等于2，否则等于1。\n",
    "\n",
    "final_states (tuple) - 最终状态,一个包含h和c的元组。形状为[num_lauers * num_directions, batch_size, hidden_size],当direction设置为bidirectional时，num_directions等于2，否则等于1。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class Seq2SeqEncoder(nn.Layer):\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):\n",
    "        super(Seq2SeqEncoder, self).__init__()\n",
    "        self.embedder = nn.Embedding(vocab_size, embed_dim)\n",
    "        self.lstm = nn.LSTM(\n",
    "            input_size=embed_dim,\n",
    "            hidden_size=hidden_size,\n",
    "            num_layers=num_layers,\n",
    "            dropout=0.2 if num_layers > 1 else 0.)\n",
    "\n",
    "    def forward(self, sequence, sequence_length):\n",
    "        inputs = self.embedder(sequence)\n",
    "        encoder_output, encoder_state = self.lstm(\n",
    "            inputs, sequence_length=sequence_length)\n",
    "        \n",
    "        # encoder_output [128, 18, 256]  [batch_size,time_steps,hidden_size]\n",
    "        # encoder_state (tuple) - 最终状态,一个包含h和c的元组。 [2, 128, 256] [2, 128, 256] [num_lauers * num_directions, batch_size, hidden_size]\n",
    "        return encoder_output, encoder_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  定义AttentionLayer\n",
    "1. `nn.Linear`线性变换层传入2个参数   \n",
    "- in_features (int) – 线性变换层输入单元的数目。\n",
    "- out_features (int) – 线性变换层输出单元的数目。\n",
    "\n",
    "![](https://ai-studio-static-online.cdn.bcebos.com/4ce5851e5eab41af9733962d1782669dbef31a74540e468ab05b730ab2bb4ffc)\n",
    "\n",
    "2. `paddle.matmul`用于计算两个Tensor的乘积，遵循完整的广播规则，关于广播规则，请参考[广播 (broadcasting)](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/broadcasting_cn.html#cn-user-guide-broadcasting) 。 并且其行为与 numpy.matmul 一致。  \n",
    "- x (Tensor) : 输入变量，类型为 Tensor，数据类型为float32， float64。\n",
    "- y (Tensor) : 输入变量，类型为 Tensor，数据类型为float32， float64。\n",
    "- transpose_x (bool，可选) : 相乘前是否转置 x，默认值为False。\n",
    "- transpose_y (bool，可选) : 相乘前是否转置 y，默认值为False。\n",
    "\n",
    "<br>\n",
    "\n",
    "3. `paddle.unsqueeze`用于向输入Tensor的Shape中一个或多个位置（axis）插入尺寸为1的维度\n",
    "\n",
    "4. `paddle.add`逐元素相加算子，输入 x 与输入 y 逐元素相加，并将各个位置的输出元素保存到返回结果中。\n",
    "\n",
    "输入 x 与输入 y 必须可以广播为相同形状。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class AttentionLayer(nn.Layer):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(AttentionLayer, self).__init__()\n",
    "        self.input_proj = nn.Linear(hidden_size, hidden_size)\n",
    "        self.output_proj = nn.Linear(hidden_size + hidden_size, hidden_size)\n",
    "\n",
    "    def forward(self, hidden, encoder_output, encoder_padding_mask):\n",
    "        encoder_output = self.input_proj(encoder_output)\n",
    "        attn_scores = paddle.matmul(\n",
    "            paddle.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)\n",
    "        # print('attention score', attn_scores.shape) #[128, 1, 18]\n",
    "\n",
    "        if encoder_padding_mask is not None:\n",
    "            attn_scores = paddle.add(attn_scores, encoder_padding_mask)\n",
    "\n",
    "        attn_scores = F.softmax(attn_scores)\n",
    "        attn_out = paddle.squeeze(\n",
    "            paddle.matmul(attn_scores, encoder_output), [1])\n",
    "        # print('1 attn_out', attn_out.shape) #[128, 256]\n",
    "\n",
    "        attn_out = paddle.concat([attn_out, hidden], 1)\n",
    "        # print('2 attn_out', attn_out.shape) #[128, 512]\n",
    "\n",
    "        attn_out = self.output_proj(attn_out)\n",
    "        # print('3 attn_out', attn_out.shape) #[128, 256]\n",
    "        return attn_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义Seq2SeqDecoderCell\n",
    "由于Decoder部分是带有attention的LSTM，我们不能复用`nn.LSTM`，所以需要定义`Seq2SeqDecoderCell`\n",
    "\n",
    "1. `nn.LayerList` 用于保存子层列表，它包含的子层将被正确地注册和添加。列表中的子层可以像常规python列表一样被索引。这里添加了num_layers=2层lstm。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class Seq2SeqDecoderCell(nn.RNNCellBase):\n",
    "    def __init__(self, num_layers, input_size, hidden_size):\n",
    "        super(Seq2SeqDecoderCell, self).__init__()\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "        self.lstm_cells = nn.LayerList([\n",
    "            nn.LSTMCell(\n",
    "                input_size=input_size + hidden_size if i == 0 else hidden_size,\n",
    "                hidden_size=hidden_size) for i in range(num_layers)\n",
    "        ])\n",
    "\n",
    "        self.attention_layer = AttentionLayer(hidden_size)\n",
    "    \n",
    "    def forward(self,\n",
    "                step_input,\n",
    "                states,\n",
    "                encoder_output,\n",
    "                encoder_padding_mask=None):\n",
    "        lstm_states, input_feed = states\n",
    "        new_lstm_states = []\n",
    "        step_input = paddle.concat([step_input, input_feed], 1)\n",
    "        for i, lstm_cell in enumerate(self.lstm_cells):\n",
    "            out, new_lstm_state = lstm_cell(step_input, lstm_states[i])\n",
    "            step_input = self.dropout(out)\n",
    "            new_lstm_states.append(new_lstm_state)\n",
    "        out = self.attention_layer(step_input, encoder_output,\n",
    "                                   encoder_padding_mask)\n",
    "        return out, [new_lstm_states, out]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义Seq2SeqDecoder\n",
    "有了`Seq2SeqDecoderCell`，就可以构建`Seq2SeqDecoder`了\n",
    "\n",
    "<br>\n",
    "\n",
    "1. `paddle.nn.RNN` 该OP是循环神经网络（RNN）的封装，将输入的Cell封装为一个循环神经网络。它能够重复执行 cell.forward() 直到遍历完input中的所有Tensor。\n",
    "- cell (RNNCellBase) - RNNCellBase类的一个实例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class Seq2SeqDecoder(nn.Layer):\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):\n",
    "        super(Seq2SeqDecoder, self).__init__()\n",
    "        self.embedder = nn.Embedding(vocab_size, embed_dim)\n",
    "        self.lstm_attention = nn.RNN(\n",
    "            Seq2SeqDecoderCell(num_layers, embed_dim, hidden_size))\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "\n",
    "    def forward(self, trg, decoder_initial_states, encoder_output,\n",
    "                encoder_padding_mask):\n",
    "        inputs = self.embedder(trg)\n",
    "\n",
    "        decoder_output, _ = self.lstm_attention(\n",
    "            inputs,\n",
    "            initial_states=decoder_initial_states,\n",
    "            encoder_output=encoder_output,\n",
    "            encoder_padding_mask=encoder_padding_mask)\n",
    "        predict = self.output_layer(decoder_output)\n",
    "\n",
    "        return predict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建主网络Seq2SeqAttnModel\n",
    "Encoder和Decoder定义好之后，网络就可以构建起来了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class Seq2SeqAttnModel(nn.Layer):\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers,\n",
    "                 eos_id=1):\n",
    "        super(Seq2SeqAttnModel, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.eos_id = eos_id\n",
    "        self.num_layers = num_layers\n",
    "        self.INF = 1e9\n",
    "        self.encoder = Seq2SeqEncoder(vocab_size, embed_dim, hidden_size,\n",
    "                                      num_layers)\n",
    "        self.decoder = Seq2SeqDecoder(vocab_size, embed_dim, hidden_size,\n",
    "                                      num_layers)\n",
    "\n",
    "    def forward(self, src, src_length, trg):\n",
    "        # encoder_output 各时刻的输出h\n",
    "        # encoder_final_state 最后时刻的输出h，和记忆信号c\n",
    "        encoder_output, encoder_final_state = self.encoder(src, src_length)\n",
    "        # print('encoder_output shape', encoder_output.shape)  #  [128, 18, 256]  [batch_size,time_steps,hidden_size]\n",
    "        # print('encoder_final_states shape', encoder_final_state[0].shape, encoder_final_state[1].shape) #[2, 128, 256] [2, 128, 256] [num_lauers * num_directions, batch_size, hidden_size]\n",
    "\n",
    "        # Transfer shape of encoder_final_states to [num_layers, 2, batch_size, hidden_size]？？？\n",
    "        encoder_final_states = [\n",
    "            (encoder_final_state[0][i], encoder_final_state[1][i])\n",
    "            for i in range(self.num_layers)\n",
    "        ]\n",
    "        # print('encoder_final_states shape', encoder_final_states[0][0].shape, encoder_final_states[0][1].shape) #[128, 256] [128, 256]\n",
    "\n",
    "\n",
    "        # Construct decoder initial states: use input_feed and the shape is\n",
    "        # [[h,c] * num_layers, input_feed], consistent with Seq2SeqDecoderCell.states\n",
    "        decoder_initial_states = [\n",
    "            encoder_final_states,\n",
    "            self.decoder.lstm_attention.cell.get_initial_states(\n",
    "                batch_ref=encoder_output, shape=[self.hidden_size])\n",
    "        ]\n",
    "\n",
    "        # Build attention mask to avoid paying attention on padddings\n",
    "        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())\n",
    "        # print ('src_mask shape', src_mask.shape)  #[128, 18]\n",
    "        # print(src_mask[0, :])\n",
    "\n",
    "        encoder_padding_mask = (src_mask - 1.0) * self.INF\n",
    "        # print ('encoder_padding_mask', encoder_padding_mask.shape)  #[128, 18]\n",
    "        # print(encoder_padding_mask[0, :])\n",
    "\n",
    "        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])\n",
    "        # print('encoder_padding_mask', encoder_padding_mask.shape)  #[128, 1, 18]\n",
    "\n",
    "        predict = self.decoder(trg, decoder_initial_states, encoder_output,\n",
    "                               encoder_padding_mask)\n",
    "        # print('predict', predict.shape)   #[128, 17, 7931]\n",
    "\n",
    "        return predict\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义损失函数\n",
    "这里使用的是交叉熵损失函数，我们需要将padding位置的loss置为0，因此需要在损失函数中引入`trg_mask`参数，由于PaddlePaddle框架提供的`paddle.nn.CrossEntropyLoss`不能接受`trg_mask`参数，因此在这里需要重新定义："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class CrossEntropyCriterion(nn.Layer):\n",
    "    def __init__(self):\n",
    "        super(CrossEntropyCriterion, self).__init__()\n",
    "\n",
    "    def forward(self, predict, label, trg_mask):\n",
    "        cost = F.softmax_with_cross_entropy(\n",
    "            logits=predict, label=label, soft_label=False)\n",
    "        cost = paddle.squeeze(cost, axis=[2])\n",
    "        masked_cost = cost * trg_mask\n",
    "        batch_mean_cost = paddle.mean(masked_cost, axis=[0])\n",
    "        seq_cost = paddle.sum(batch_mean_cost)\n",
    "\n",
    "        return seq_cost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 执行过程\n",
    "## 训练过程\n",
    "使用高层API执行训练，需要调用`prepare`和`fit`函数。\n",
    "\n",
    "在`prepare`函数中，配置优化器、损失函数，以及评价指标。其中评价指标使用的是PaddleNLP提供的困惑度计算API `paddlenlp.metrics.Perplexity`。\n",
    "\n",
    "如果你安装了VisualDL，可以在fit中添加一个callbacks参数使用VisualDL观测你的训练过程，如下：\n",
    "\n",
    "``` python\n",
    "model.fit(train_data=train_loader,\n",
    "            epochs=max_epoch,\n",
    "            eval_freq=1,\n",
    "            save_freq=1,\n",
    "            save_dir=model_path,\n",
    "            log_freq=log_freq,\n",
    "            callbacks=[paddle.callbacks.VisualDL('./log')])\n",
    "```\n",
    "\n",
    "在这里，由于对联生成任务没有明确的评价指标，因此，可以在保存的多个模型中，通过人工评判生成结果选择最好的模型。\n",
    "\n",
    "本项目中，为了便于演示，已经将训练好的模型参数载入模型，并省略了训练过程。读者自己实验的时候，可以尝试自行修改超参数，调用下面被注释掉的`fit`函数，重新进行训练。\n",
    "\n",
    "如果读者想要在更短的时间内得到效果不错的模型，可以使用预训练模型技术，例如[《预训练模型ERNIE-GEN自动写诗》](https://aistudio.baidu.com/aistudio/projectdetail/1339888)项目为大家展示了如何利用预训练的生成模型进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The loss value printed in the log is the current step, and the metric is the average value of previous step.\n",
      "Epoch 1/20\n",
      "step 1373/1373 [==============================] - loss: 45.4127 - Perplexity: 305.7490 - ETA: 35s - 95ms/ste - loss: 44.8421 - Perplexity: 244.7160 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/0\n",
      "Epoch 2/20\n",
      "step 1373/1373 [==============================] - loss: 41.0244 - Perplexity: 102.2719 - ETA: 35s - 96ms/ste - loss: 40.8540 - Perplexity: 97.1830 - 96ms/step            \n",
      "save checkpoint at /home/aistudio/couplet_models/1\n",
      "Epoch 3/20\n",
      "step 1373/1373 [==============================] - loss: 38.5902 - Perplexity: 73.4031 - ETA: 35s - 95ms/st - loss: 38.2472 - Perplexity: 71.4694 - 95ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/2\n",
      "Epoch 4/20\n",
      "step 1373/1373 [==============================] - loss: 37.0881 - Perplexity: 60.6595 - ETA: 35s - 96ms/st - loss: 36.4279 - Perplexity: 59.6021 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/3\n",
      "Epoch 5/20\n",
      "step 1373/1373 [==============================] - loss: 36.0893 - Perplexity: 53.2231 - ETA: 35s - 96ms/st - loss: 35.0463 - Perplexity: 52.5910 - 97ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/4\n",
      "Epoch 6/20\n",
      "step 1373/1373 [==============================] - loss: 35.3010 - Perplexity: 48.4197 - ETA: 36s - 97ms/st - loss: 33.8170 - Perplexity: 47.9995 - 97ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/5\n",
      "Epoch 7/20\n",
      "step 1373/1373 [==============================] - loss: 34.7849 - Perplexity: 45.0536 - ETA: 35s - 96ms/st - loss: 32.7696 - Perplexity: 44.7681 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/6\n",
      "Epoch 8/20\n",
      "step 1373/1373 [==============================] - loss: 34.2624 - Perplexity: 42.5123 - ETA: 35s - 96ms/st - loss: 31.7726 - Perplexity: 42.3056 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/7\n",
      "Epoch 9/20\n",
      "step 1373/1373 [==============================] - loss: 33.8583 - Perplexity: 40.5478 - ETA: 35s - 96ms/st - loss: 30.6028 - Perplexity: 40.3765 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/8\n",
      "Epoch 10/20\n",
      "step 1373/1373 [==============================] - loss: 33.4531 - Perplexity: 38.9016 - ETA: 35s - 96ms/st - loss: 30.0247 - Perplexity: 38.7874 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/9\n",
      "Epoch 11/20\n",
      "step 1373/1373 [==============================] - loss: 33.1921 - Perplexity: 37.5176 - ETA: 35s - 96ms/st - loss: 29.2094 - Perplexity: 37.4202 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/10\n",
      "Epoch 12/20\n",
      "step 1373/1373 [==============================] - loss: 32.9125 - Perplexity: 36.3842 - ETA: 35s - 96ms/st - loss: 28.6912 - Perplexity: 36.3062 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/11\n",
      "Epoch 13/20\n",
      "step 1373/1373 [==============================] - loss: 32.7590 - Perplexity: 35.3640 - ETA: 35s - 96ms/st - loss: 27.6492 - Perplexity: 35.3034 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/12\n",
      "Epoch 14/20\n",
      "step 1373/1373 [==============================] - loss: 32.5031 - Perplexity: 34.4599 - ETA: 35s - 96ms/st - loss: 27.3593 - Perplexity: 34.4202 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/13\n",
      "Epoch 15/20\n",
      "step 1373/1373 [==============================] - loss: 32.3385 - Perplexity: 33.6792 - ETA: 35s - 96ms/st - loss: 26.7327 - Perplexity: 33.6405 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/14\n",
      "Epoch 16/20\n",
      "step 1373/1373 [==============================] - loss: 32.1561 - Perplexity: 32.9900 - ETA: 35s - 95ms/st - loss: 26.4542 - Perplexity: 32.9531 - 95ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/15\n",
      "Epoch 17/20\n",
      "step 1373/1373 [==============================] - loss: 32.0037 - Perplexity: 32.3352 - ETA: 35s - 96ms/st - loss: 25.7733 - Perplexity: 32.3058 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/16\n",
      "Epoch 18/20\n",
      "step 1373/1373 [==============================] - loss: 31.9462 - Perplexity: 31.7711 - ETA: 35s - 96ms/st - loss: 25.0613 - Perplexity: 31.7515 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/17\n",
      "Epoch 19/20\n",
      "step 1373/1373 [==============================] - loss: 31.8613 - Perplexity: 31.2189 - ETA: 36s - 97ms/st - loss: 25.1381 - Perplexity: 31.2192 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/18\n",
      "Epoch 20/20\n",
      "step 1373/1373 [==============================] - loss: 31.5394 - Perplexity: 30.7285 - ETA: 35s - 96ms/st - loss: 24.4705 - Perplexity: 30.7291 - 96ms/step           \n",
      "save checkpoint at /home/aistudio/couplet_models/19\n",
      "save checkpoint at /home/aistudio/couplet_models/final\n"
     ]
    }
   ],
   "source": [
    "model = paddle.Model(\n",
    "    Seq2SeqAttnModel(vocab_size, hidden_size, hidden_size,\n",
    "                        num_layers, pad_id))\n",
    "\n",
    "optimizer = paddle.optimizer.Adam(\n",
    "    learning_rate=learning_rate, parameters=model.parameters())\n",
    "ppl_metric = Perplexity()\n",
    "model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)\n",
    "\n",
    "model.fit(train_data=train_loader,\n",
    "            epochs=max_epoch,\n",
    "            # epochs=2,\n",
    "            eval_freq=1,\n",
    "            save_freq=1,\n",
    "            save_dir=model_path,\n",
    "            verbose=1,\n",
    "            log_freq=log_freq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型预测\n",
    "\n",
    "### 定义预测网络Seq2SeqAttnInferModel\n",
    "预测网络继承上面的主网络`Seq2SeqAttnModel`，定义子类`Seq2SeqAttnInferModel`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class Seq2SeqAttnInferModel(Seq2SeqAttnModel):\n",
    "    def __init__(self,\n",
    "                 vocab_size,\n",
    "                 embed_dim,\n",
    "                 hidden_size,\n",
    "                 num_layers,\n",
    "                 bos_id=0,\n",
    "                 eos_id=1,\n",
    "                 beam_size=4,\n",
    "                 max_out_len=256):\n",
    "        self.bos_id = bos_id\n",
    "        self.beam_size = beam_size\n",
    "        self.max_out_len = max_out_len\n",
    "        self.num_layers = num_layers\n",
    "        super(Seq2SeqAttnInferModel, self).__init__(\n",
    "            vocab_size, embed_dim, hidden_size, num_layers, eos_id)\n",
    "\n",
    "        # Dynamic decoder for inference\n",
    "        self.beam_search_decoder = nn.BeamSearchDecoder(\n",
    "            self.decoder.lstm_attention.cell,\n",
    "            start_token=bos_id,\n",
    "            end_token=eos_id,\n",
    "            beam_size=beam_size,\n",
    "            embedding_fn=self.decoder.embedder,\n",
    "            output_fn=self.decoder.output_layer)\n",
    "\n",
    "    def forward(self, src, src_length):\n",
    "        encoder_output, encoder_final_state = self.encoder(src, src_length)\n",
    "\n",
    "        encoder_final_state = [\n",
    "            (encoder_final_state[0][i], encoder_final_state[1][i])\n",
    "            for i in range(self.num_layers)\n",
    "        ]\n",
    "\n",
    "        # Initial decoder initial states\n",
    "        decoder_initial_states = [\n",
    "            encoder_final_state,\n",
    "            self.decoder.lstm_attention.cell.get_initial_states(\n",
    "                batch_ref=encoder_output, shape=[self.hidden_size])\n",
    "        ]\n",
    "        # Build attention mask to avoid paying attention on paddings\n",
    "        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())\n",
    "\n",
    "        encoder_padding_mask = (src_mask - 1.0) * self.INF\n",
    "        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])\n",
    "\n",
    "        # Tile the batch dimension with beam_size\n",
    "        encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(\n",
    "            encoder_output, self.beam_size)\n",
    "        encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(\n",
    "            encoder_padding_mask, self.beam_size)\n",
    "\n",
    "        # Dynamic decoding with beam search\n",
    "        seq_output, _ = nn.dynamic_decode(\n",
    "            decoder=self.beam_search_decoder,\n",
    "            inits=decoder_initial_states,\n",
    "            max_step_num=self.max_out_len,\n",
    "            encoder_output=encoder_output,\n",
    "            encoder_padding_mask=encoder_padding_mask)\n",
    "        return seq_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 解码部分\n",
    "接下来对我们的任务选择beam search解码方式，可以指定beam_size为10。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):\n",
    "    \"\"\"\n",
    "    Post-process the decoded sequence.\n",
    "    \"\"\"\n",
    "    eos_pos = len(seq) - 1\n",
    "    for i, idx in enumerate(seq):\n",
    "        if idx == eos_idx:\n",
    "            eos_pos = i\n",
    "            break\n",
    "    seq = [\n",
    "        idx for idx in seq[:eos_pos + 1]\n",
    "        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)\n",
    "    ]\n",
    "    return seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "beam_size = 10\n",
    "'''\n",
    "init_from_ckpt = './couplet_models/19' # for test\n",
    "infer_output_file = './infer_output.txt'\n",
    "\n",
    "# test_loader, vocab_size, pad_id, bos_id, eos_id = create_data_loader(test_ds, batch_size)\n",
    "test_loader = create_data_loader(test_ds)\n",
    "vocab_size = batch_size\n",
    "pad_id = \n",
    "bos_id = 0\n",
    "eos_id = 1\n",
    "vocab, _ = CoupletDataset.get_vocab()\n",
    "trg_idx2word = vocab.idx_to_token\n",
    "'''\n",
    "model = paddle.Model(\n",
    "    Seq2SeqAttnInferModel(\n",
    "        vocab_size,\n",
    "        hidden_size,\n",
    "        hidden_size,\n",
    "        num_layers,\n",
    "        bos_id=bos_id,\n",
    "        eos_id=eos_id,\n",
    "        beam_size=beam_size,\n",
    "        max_out_len=256))\n",
    "\n",
    "model.prepare()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在预测之前，我们需要将训练好的模型参数load进预测网络，之后我们就可以根据对联的上联，生成对联的下联啦！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.load('couplet_models/model_18')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "test_ds = CoupletDataset.get_datasets(['test'])\n",
    "idx = 0\n",
    "for data in test_loader():\n",
    "    inputs = data[:2]\n",
    "    finished_seq = model.predict_batch(inputs=list(inputs))[0]\n",
    "    finished_seq = finished_seq[:, :, np.newaxis] if len(\n",
    "        finished_seq.shape) == 2 else finished_seq\n",
    "    finished_seq = np.transpose(finished_seq, [0, 2, 1])\n",
    "    for ins in finished_seq:\n",
    "        for beam in ins:\n",
    "            id_list = post_process_seq(beam, bos_id, eos_id)\n",
    "            word_list_l = [trg_idx2word[id] for id in test_ds[idx][0]][1:-1]\n",
    "            word_list_r = [trg_idx2word[id] for id in id_list]\n",
    "            sequence = \"上联: \"+\" \".join(word_list_l)+\"\\t下联: \"+\" \".join(word_list_r) + \"\\n\"\n",
    "            print(sequence)\n",
    "            idx += 1\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PaddleNLP更多教程\n",
    "- [使用seq2vec模块进行句子情感分析](https://aistudio.baidu.com/aistudio/projectdetail/1283423)\n",
    "- [使用预训练模型ERNIE优化情感分析](https://aistudio.baidu.com/aistudio/projectdetail/1294333)\n",
    "- [使用BiGRU-CRF模型完成快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1317771)\n",
    "- [使用预训练模型ERNIE优化快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1329361)\n",
    "- [使用预训练模型ERNIE-GEN实现智能写诗](https://aistudio.baidu.com/aistudio/projectdetail/1339888)\n",
    "- [使用TCN网络完成新冠疫情病例数预测](https://aistudio.baidu.com/aistudio/projectdetail/1290873)\n",
    "- [使用预训练模型完成阅读理解](https://aistudio.baidu.com/aistudio/projectdetail/1339612)\n",
    "- [自定义数据集实现文本多分类任务](https://aistudio.baidu.com/aistudio/projectdetail/1468469)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加入交流群，一起学习吧\n",
    "\n",
    "现在就加入PaddleNLP的QQ技术交流群，一起交流NLP技术吧！\n",
    "\n",
    "<img src=\"https://ai-studio-static-online.cdn.bcebos.com/d953727af0c24a7c806ab529495f0904f22f809961be420b8c88cdf59b837394\" width=\"200\" height=\"250\" >"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:paddlepaddle]",
   "language": "python",
   "name": "conda-env-paddlepaddle-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
